{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Zfjv9nzQSyZ0"
   },
   "outputs": [],
   "source": [
    "!pip install -q -U transformers peft torch accelerate einops sentencepiece bitsandbytes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "H-6hb4ZQTUkz"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from peft import PeftModel, PeftConfig\n",
    "from transformers import (\n",
    "    AutoModelForCausalLM,\n",
    "    AutoTokenizer,\n",
    "    BitsAndBytesConfig,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 49,
     "referenced_widgets": [
      "ff857b2fcf364766b5e8662b35f2ceb3",
      "1f1aa816908f413c87ad0ad7c0c2f253",
      "aa9cc1fbe4bc4f039848a72311912c9c",
      "baf56fb031aa4960ae69e2b5e006e550",
      "bc88ca712eb042d8becc601f2906dda2",
      "30465b429a854da9a70ff325b028caf9",
      "c22f4e9e6c0a4ec2893188fa1084f423",
      "b2d1dc24fabc4b218edb0f80ae7c77eb",
      "0f8a0a5ef8334f5cbea6de99b365da68",
      "2b46ffd07f7346ecb03289fd96be0931",
      "1f3a5707f83041e48aa6b2192c1f3468"
     ]
    },
    "id": "G8mb8K3zTWWx",
    "outputId": "1018e6f6-59ac-42b3-d191-96e33d0ffadd"
   },
   "outputs": [],
   "source": [
    "peft_model_id = \"dfurman/Mixtral-8x7B-peft-v0.1\"\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    peft_model_id,\n",
    "    use_fast=True,\n",
    "    trust_remote_code=True,\n",
    ")\n",
    "\n",
    "bnb_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_quant_type=\"nf4\",\n",
    "    bnb_4bit_compute_dtype=torch.bfloat16,\n",
    ")\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    config.base_model_name_or_path,\n",
    "    quantization_config=bnb_config,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    device_map=\"auto\",\n",
    "    trust_remote_code=True,\n",
    ")\n",
    "\n",
    "model = PeftModel.from_pretrained(model, peft_model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "2cnoftm6TYXM",
    "outputId": "4ec10bbe-7b01-49ab-997a-5e7e03bb81c5"
   },
   "outputs": [],
   "source": [
    "messages = [\n",
    "    {\"role\": \"user\", \"content\": \"Tell me a recipe for a mai tai.\"},\n",
    "]\n",
    "\n",
    "print(\"\\n\\n*** Prompt:\")\n",
    "input_ids = tokenizer.apply_chat_template(\n",
    "    messages,\n",
    "    tokenize=True,\n",
    "    return_tensors=\"pt\",\n",
    ")\n",
    "print(tokenizer.decode(input_ids[0]))\n",
    "\n",
    "print(\"\\n\\n*** Generate:\")\n",
    "with torch.autocast(\"cuda\", dtype=torch.bfloat16):\n",
    "    output = model.generate(\n",
    "        input_ids=input_ids.to(\"cuda\"),\n",
    "        max_new_tokens=1024,\n",
    "        return_dict_in_generate=True,\n",
    "    )\n",
    "\n",
    "response = tokenizer.decode(\n",
    "    output[\"sequences\"][0][len(input_ids[0]) :], skip_special_tokens=True\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "NOTURK6nXoYZ",
    "outputId": "a571c310-9d1f-4381-a844-86663c836b77"
   },
   "outputs": [],
   "source": [
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"Recommend some games to play for 3 year old and 7 year olds.\",\n",
    "    },\n",
    "]\n",
    "\n",
    "print(\"\\n\\n*** Prompt:\")\n",
    "input_ids = tokenizer.apply_chat_template(\n",
    "    messages,\n",
    "    tokenize=True,\n",
    "    return_tensors=\"pt\",\n",
    ")\n",
    "print(tokenizer.decode(input_ids[0]))\n",
    "\n",
    "print(\"\\n\\n*** Generate:\")\n",
    "with torch.autocast(\"cuda\", dtype=torch.bfloat16):\n",
    "    output = model.generate(\n",
    "        input_ids=input_ids.to(\"cuda\"),\n",
    "        max_new_tokens=1024,\n",
    "        return_dict_in_generate=True,\n",
    "    )\n",
    "\n",
    "response = tokenizer.decode(\n",
    "    output[\"sequences\"][0][len(input_ids[0]) :], skip_special_tokens=True\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8Bw1lhjhYhw1"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "A100",
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "0f8a0a5ef8334f5cbea6de99b365da68": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "1f1aa816908f413c87ad0ad7c0c2f253": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_30465b429a854da9a70ff325b028caf9",
      "placeholder": "​",
      "style": "IPY_MODEL_c22f4e9e6c0a4ec2893188fa1084f423",
      "value": "Loading checkpoint shards: 100%"
     }
    },
    "1f3a5707f83041e48aa6b2192c1f3468": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "2b46ffd07f7346ecb03289fd96be0931": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "30465b429a854da9a70ff325b028caf9": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "aa9cc1fbe4bc4f039848a72311912c9c": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_b2d1dc24fabc4b218edb0f80ae7c77eb",
      "max": 19,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_0f8a0a5ef8334f5cbea6de99b365da68",
      "value": 19
     }
    },
    "b2d1dc24fabc4b218edb0f80ae7c77eb": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "baf56fb031aa4960ae69e2b5e006e550": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_2b46ffd07f7346ecb03289fd96be0931",
      "placeholder": "​",
      "style": "IPY_MODEL_1f3a5707f83041e48aa6b2192c1f3468",
      "value": " 19/19 [04:13&lt;00:00, 18.79s/it]"
     }
    },
    "bc88ca712eb042d8becc601f2906dda2": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "c22f4e9e6c0a4ec2893188fa1084f423": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "ff857b2fcf364766b5e8662b35f2ceb3": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_1f1aa816908f413c87ad0ad7c0c2f253",
       "IPY_MODEL_aa9cc1fbe4bc4f039848a72311912c9c",
       "IPY_MODEL_baf56fb031aa4960ae69e2b5e006e550"
      ],
      "layout": "IPY_MODEL_bc88ca712eb042d8becc601f2906dda2"
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
